from io import SEEK_END, SEEK_SET, BytesIO
import math
import io
from s4studio.core import Serializable, ChildElement
from s4studio.helpers import FNV32, Flag, first
from s4studio.io import StreamWriter, StreamReader, RCOL
from s4studio.model import LOD_ID
from s4studio.model.geometry import Mesh, SkinController, BoundingBox, Vertex
from s4studio.model.material import MaterialDefinition, MaterialSet, ScaleOffsets, MaterialBlock


def calculate_scale(extremes, max_values, channels=False, init=False, floor=False):
    """Compute per-channel quantization scales (reciprocal multipliers).

    :param extremes: either a flat iterable of sample values (``channels`` False)
        or a sequence of per-channel iterables (``channels`` True).
    :param max_values: largest integer magnitude each channel may be encoded as.
    :param channels: when True, ``extremes[i]`` holds the samples of channel ``i``.
    :param init: when True, start from ``1 / max_values[0]`` for the first three
        slots and ``1.0`` for the fourth; otherwise start from zeros.
    :param floor: floor the peak magnitude before comparing it with 1.0.
    :returns: a 4-element list of scales (slots past ``len(max_values)`` keep
        their initial value).
    """
    if init:
        base = 1.0 / max_values[0]
        scales = [base, base, base, 1.0]
    else:
        scales = [0.0, 0.0, 0.0, 0.0]

    for index, limit in enumerate(max_values):
        if limit != float(0x1FF):
            samples = extremes[index] if channels else extremes
            peak = max(abs(sample) for sample in samples)
            if floor:
                peak = math.floor(peak)
            unit = 1 / float(limit)
            # Values larger than 1.0 need a coarser step so the peak still fits
            # within ``limit`` integer units.
            if peak > 1.0:
                unit = 1.0 / math.floor(limit / peak)
        else:
            # 0x1FF-range formats always use the plain reciprocal.
            unit = 1 / float(limit)
        scales[index] = unit
    return scales


def calculate_pos_scales(bounds, max_value):
    """Compute position quantization scales from a bounding box.

    The largest absolute coordinate across ``bounds.min`` and ``bounds.max`` is
    rounded up to an integer and used uniformly for all three axes.

    :param bounds: object exposing iterable ``min`` and ``max`` coordinates.
    :param max_value: per-axis maximum encodable magnitudes.
    :returns: the scale list produced by :func:`calculate_scale`.
    """
    corners = list(bounds.min) + list(bounds.max)
    radius = math.ceil(max(abs(c) for c in corners))
    per_axis = [[radius], [radius], [radius]]
    return calculate_scale(per_axis, max_value, True, True, True)


def calculate_uv_scales(vertices, max_value):
    """Compute UV quantization scales from the UV magnitudes of ``vertices``.

    :param vertices: iterable of :class:`Vertex`; vertices whose ``uv`` is
        None or empty are ignored.
    :param max_value: per-UV-channel maximum encodable magnitudes.
    :returns: the scale list from :func:`calculate_scale`, or None when
        ``max_value`` is empty/None. When the second slot comes back as 0 it is
        filled with the first slot's value.
    """
    if not max_value:
        return None
    # Largest |component| seen per UV channel (up to 10 channels).
    extremes = [0.0] * 10
    for v in vertices:
        assert isinstance(v, Vertex)
        if not v.uv:
            # Vertices decoded from a format without UV declarations keep uv=None.
            continue
        for usage_index, uv in enumerate(v.uv):
            for u in uv:
                extremes[usage_index] = max(abs(u), extremes[usage_index])

    extremes = extremes[:len(max_value)]
    calc = calculate_scale([[e] for e in extremes], max_value, True, False, False)
    if calc[1] == 0:
        calc[1] = calc[0]
    return calc


class ObjectMesh(Serializable, ChildElement, Mesh):
    """A single mesh entry inside a ModelLod (MLOD) resource.

    Material, vertex format, vertex buffer, index buffer, skin controller and
    scale offsets are separate RCOL blocks referenced by index; they are
    resolved through the ``rcol`` argument of :meth:`read` / :meth:`write`.
    """

    class Flags:
        """Mesh flag bits (stored shifted left by 8 in the packed flags word)."""
        BASIN_INTERIOR = 0x00000001
        HD_EXTERIOR_LIT = 0x00000002
        PORTAL_SIDE = 0x00000004
        DROP_SHADOW = 0x00000008
        SHADOW_CASTER = 0x00000010
        FOUNDATION = 0x00000020
        PICKABLE = 0x00000040

    class PrimitiveType:
        """Primitive topology (stored in the low byte of the packed flags word)."""
        POINT_LIST = 0x00000000
        LINE_LIST = 0x00000001
        LINE_STRIP = 0x00000002
        TRIANGLE_LIST = 0x00000003
        TRIANGLE_FAN = 0x00000004
        TRIANGLE_STRIP = 0x00000005
        QUAD_LIST = 0x00000006
        DISPLAY_LIST = 0x00000007

    class State(Serializable):
        """A sub-range of the mesh, selected by a state hash, with its own index/vertex window."""

        def __init__(self, stream=None):
            self.state = 0
            self.start_index = 0
            self.min_vertex_index = 0
            self.vertex_count = 0
            self.primitive_count = 0
            Serializable.__init__(self, stream)

        def read(self, stream, resource=None):
            """Read a u32 state hash followed by four i32 range fields."""
            s = StreamReader(stream)
            self.state = s.u32()
            self.start_index = s.i32()
            self.min_vertex_index = s.i32()
            self.vertex_count = s.i32()
            self.primitive_count = s.i32()

        def write(self, stream, resource=None):
            """Write the fields in the same order :meth:`read` consumes them."""
            s = StreamWriter(stream)
            s.u32(self.state)
            s.i32(self.start_index)
            s.i32(self.min_vertex_index)
            s.i32(self.vertex_count)
            s.i32(self.primitive_count)

    def is_dropshadow(self):
        """Return True when the material is a MaterialDefinition whose shader hash is 0xC09C7582.

        (Presumably the drop-shadow shader, judging by the method name.)
        """
        return isinstance(self.material, MaterialDefinition) and int(self.material.shader_name) == 0xC09C7582

    def __init__(self, stream=None, resources=None, parent=None):
        self.name = None
        self.material = None
        self.vertex_format = None
        self.vertex_buffer = None
        self.index_buffer = None
        self.flags = self.Flags.PICKABLE
        self.primitive_type = self.PrimitiveType.TRIANGLE_LIST
        self.stream_offset = 0
        self.start_vertex = 0
        self.start_index = 0
        self.min_vertex_index = 0
        self.vertex_count = 0
        self.primitive_count = 0
        self.bounds = BoundingBox()
        self.skin_controller = None
        self.bone_references = []
        self.scale_offsets = None
        self.states = []
        self.parent_name = 0
        self.mirror_plane_normal = [0.0, 0.0, 0.0]
        self.mirror_plane_offset = 0.0
        self.unknown = 0
        self.extra_bounds = []
        ChildElement.__init__(self, parent)
        Serializable.__init__(self, stream, resources)

    def get_vertex_format(self):
        """Return the mesh's vertex format, or a default one when none is set.

        Shadow-caster meshes default to the sun-shadow layout; others to the
        drop-shadow layout.
        """
        if self.vertex_format is not None:
            return self.vertex_format
        if Flag.is_set(self.flags, self.Flags.SHADOW_CASTER):
            return VertexFormat.default_sunshadow()
        return VertexFormat.default_drop_shadow()

    def read(self, stream, rcol):
        """Deserialize this mesh from ``stream``, resolving block references via ``rcol``.

        Raises:
            Exception: if the trailing per-bone bounding boxes do not match the
                bone reference count, or if the bytes consumed differ from the
                length prefix.
        """
        s = StreamReader(stream)
        data_len = s.u32()
        end = stream.tell() + data_len
        self.name = s.u32()
        self.material = rcol.get_block(s.u32(), (MaterialDefinition, MaterialSet))
        self.vertex_format = rcol.get_block(s.u32(), VertexFormat)
        self.vertex_buffer = rcol.get_block(s.u32(), (VertexBuffer, VertexBufferShadow))
        self.index_buffer = rcol.get_block(s.u32(), (IndexBuffer, IndexBufferShadow))
        # Packed word: low byte is the primitive type, the rest are mesh flags.
        flags = s.u32()
        self.flags = flags >> 8
        self.primitive_type = flags & 0x000000FF
        self.stream_offset = s.u32()
        self.start_vertex = s.i32()
        self.start_index = s.i32()
        self.min_vertex_index = s.i32()
        self.vertex_count = s.i32()
        self.primitive_count = s.i32()
        self.bounds.read(stream)
        self.skin_controller = rcol.get_block(s.u32(), ObjectSkinController)
        self.bone_references = [s.u32() for i in range(s.i32())]
        self.scale_offsets = rcol.get_block(s.u32(), ScaleOffsets)
        self.states = [self.State(stream) for i in range(s.i32())]
        if self.parent.version > ModelLod.VERSION.DEFAULT:
            self.parent_name = s.u32()
            self.mirror_plane_normal = [s.f32() for i in range(3)]
            self.mirror_plane_offset = s.f32()
        self.unknown = s.u32()
        # Reset so that reading the same object twice does not accumulate boxes.
        self.extra_bounds = []
        if self.parent.version >= 0x00000206 and stream.tell() < end:
            # Remaining bytes hold one bounding box (6 x 4-byte values) per bone.
            count, remainder = divmod(end - stream.tell(), 4 * 6)
            if remainder or count != len(self.bone_references):
                raise Exception("Invalid bounding box count for bones")
            for _ in range(count):
                b = BoundingBox()
                b.read(stream)
                self.extra_bounds.append(b)
        if not stream.tell() == end: raise Exception(
            "Invalid MLOD.Mesh data length: expected 0x%X, but got 0x%08X" % (end, stream.tell()))

    def write(self, stream, rcol):
        """Serialize this mesh; the u32 length prefix is back-patched once the body is written."""
        s = StreamWriter(stream)
        len_offset = stream.tell()
        s.u32(0)  # placeholder for the length prefix
        start = stream.tell()
        s.hash(self.name)
        s.u32(rcol.get_block_index(self.material))
        s.u32(rcol.get_block_index(self.vertex_format))
        s.u32(rcol.get_block_index(self.vertex_buffer))
        s.u32(rcol.get_block_index(self.index_buffer))
        s.u32(self.primitive_type | (self.flags << 8))
        s.u32(self.stream_offset)
        s.i32(self.start_vertex)
        s.i32(self.start_index)
        s.i32(self.min_vertex_index)
        s.i32(self.vertex_count)
        s.i32(self.primitive_count)
        self.bounds.write(stream)
        s.u32(rcol.get_block_index(self.skin_controller))
        s.i32(len(self.bone_references))
        for bone in self.bone_references:
            s.u32(bone)
        s.u32(rcol.get_block_index(self.scale_offsets))
        s.i32(len(self.states))
        for state in self.states:
            state.write(stream)
        if self.parent.version > ModelLod.VERSION.DEFAULT:
            s.hash(self.parent_name)
            for i in range(3):
                s.f32(self.mirror_plane_normal[i])
            s.f32(self.mirror_plane_offset)
        s.u32(self.unknown)
        if self.parent.version >= 0x00000206 and any(self.extra_bounds):
            for b in self.extra_bounds:
                assert isinstance(b, BoundingBox)
                b.write(stream)
        end = stream.tell()
        stream.seek(len_offset, SEEK_SET)
        s.u32(end - start)
        stream.seek(end, SEEK_SET)

    def set_pos_scales(self, pos_scale):
        """Store ``pos_scale`` as the 'posscale' float item of the scale-offsets block, if any."""
        material = self.scale_offsets
        if material is None:
            return None
        material.material_block['posscale'] = (pos_scale, MaterialBlock.Item.TYPE.FLOAT)

    def get_pos_scale(self):
        """Return the position scale for this mesh.

        Returns None when there is no scale-offsets block. Otherwise returns the
        stored 'posscale' item when present, or the reciprocal of the vertex
        format's position max sizes followed by 1.0.
        """
        vrtf = self.get_vertex_format()
        assert isinstance(vrtf, VertexFormat)
        max_sizes = vrtf.max_size_for_usage(VertexFormat.USAGE.POSITION)
        val = [1 / v for v in max_sizes]
        val.append(1.0)
        key = 'posscale'
        material = self.scale_offsets
        # NOTE(review): the computed default is discarded when scale_offsets is
        # None; callers currently rely on the None result.
        if material is None:
            return None
        if key in material.material_block:
            val = material.material_block[key]
        return val

    def get_material_definition(self):
        """Return the mesh's MaterialDefinition, or None if it cannot be found.

        A MaterialSet resolves to its first element whose material is a
        MaterialDefinition. A set with no such element yields None.

        Raises:
            Exception: if the material is neither a MaterialDefinition nor a MaterialSet.
        """
        material = self.material
        if material is None:
            return None
        if isinstance(material, MaterialDefinition):
            return material
        if not isinstance(material, MaterialSet):
            raise Exception("Expected a MaterialDefinition or MaterialSet")
        for entry in material.elements:
            if isinstance(entry.material, MaterialDefinition):
                return entry.material
        # No usable definition in the set (previously this looped forever).
        return None

    def has_uv_scales(self):
        """Return True when the resolved material definition has a 'uvscales' item."""
        key = 'uvscales'
        material = self.get_material_definition()
        return material is not None and key in material.material_block

    def get_uv_scales(self):
        """Return the material's 'uvscales' item, or four copies of 1/0x7FFF by default."""
        uvscales = [1 / float(0x7fff)] * 4
        key = 'uvscales'
        material = self.get_material_definition()
        if material is not None and key in material.material_block:
            uvscales = material.material_block[key]
        return uvscales

    def set_uv_scales(self, uvscales):
        """Store ``uvscales`` on the material, or on every element's material for a MaterialSet."""
        material = self.material
        if material is None:
            return
        elif isinstance(material, MaterialDefinition):
            material.material_block['uvscales'] = (uvscales, MaterialBlock.Item.TYPE.FLOAT)
        elif isinstance(material, MaterialSet):
            for entry in material.elements:
                assert isinstance(entry, MaterialSet.Element)
                entry.material.material_block['uvscales'] = (uvscales, MaterialBlock.Item.TYPE.FLOAT)

    def get_vertices(self, state_hash=None):
        """Decode this mesh's vertices from its vertex buffer.

        When ``state_hash`` matches one of ``self.states``, only that state's
        vertex window is decoded; otherwise the full mesh range is used.
        """
        uvscales = self.get_uv_scales()
        vrtf = self.get_vertex_format()
        offset = self.stream_offset
        vertex_count = self.vertex_count
        if state_hash is not None:
            for s in self.states:
                assert isinstance(s, self.State)
                if s.state == state_hash:
                    offset = self.stream_offset + (vrtf.stride * s.min_vertex_index)
                    vertex_count = s.vertex_count
                    break
        return self.vertex_buffer.buffer.read_vertices(offset, vrtf, vertex_count, uvscales)

    def get_triangles(self, state_hash=None):
        """Return the mesh's triangles as lists of three vertex indices.

        With a matching ``state_hash``, only that state's primitives are
        returned, and indices are made relative to its ``min_vertex_index``.

        Raises:
            NotImplementedError: for any primitive type other than TRIANGLE_LIST.
        """
        if self.primitive_type != self.PrimitiveType.TRIANGLE_LIST:
            raise NotImplementedError()
        primitive_size = 3
        start_index = self.start_index
        primitive_count = self.primitive_count
        min_vertex = 0

        if state_hash is not None:
            for s in self.states:
                assert isinstance(s, self.State)
                if s.state == state_hash:
                    start_index = s.start_index
                    primitive_count = s.primitive_count
                    min_vertex = s.min_vertex_index
                    break
        indices = self.index_buffer.buffer
        return [[indices[start_index + (primitive_index * primitive_size) + i] - min_vertex
                 for i in range(primitive_size)]
                for primitive_index in range(primitive_count)]


class ModelLod(RCOL):
    """MLOD resource: a version number followed by a counted list of ObjectMesh entries."""
    TAG = 'MLOD'
    ID = 0x01D10F34

    class VERSION:
        DEFAULT = 0x00000201
        EXTENDED = 0x00000202

    def __init__(self, key=None):
        RCOL.__init__(self, key)
        self.version = 0
        self.meshes = []

    def read_rcol(self, stream, rcol):
        """Read the tag, version and meshes; each mesh is constructed with this LOD as parent."""
        self.read_tag(stream)
        reader = StreamReader(stream)
        self.version = reader.u32()
        meshes = self.meshes = []
        mesh_count = reader.i32()
        for _ in range(mesh_count):
            meshes.append(ObjectMesh(stream, rcol, self))

    def write_rcol(self, stream, rcol):
        """Write the tag, version, mesh count and every mesh in order."""
        self.write_tag(stream)
        writer = StreamWriter(stream)
        writer.u32(self.version)
        writer.i32(len(self.meshes))
        for lod_mesh in self.meshes:
            lod_mesh.write(stream, rcol)

class IndexBuffer(RCOL):
    """IBUF resource: a flat list of vertex indices.

    Indices are stored as 16- or 32-bit signed integers (INDEX_32 flag) and
    may be delta-encoded against the previous index (DIFFERENCED_INDICES).
    """
    TAG = 'IBUF'
    ID = 0x01D0E70F

    class VERSION:
        DEFAULT = 0x000000100

    class FLAGS:
        DIFFERENCED_INDICES = 0x00000001
        INDEX_32 = 0x00000002
        DISPLAY_LIST = 0x00000004

    def __init__(self, key=None):
        RCOL.__init__(self, key)
        self.version = self.VERSION.DEFAULT
        self.buffer = []
        self.flags = 0
        self.unknown = 0

    def read_rcol(self, stream, rcol):
        """Read the header, then decode indices until the end of ``stream``."""
        self.read_tag(stream)
        s = StreamReader(stream)
        self.version = s.u32()
        self.flags = s.u32()
        self.unknown = s.u32()
        start = stream.tell()
        stream.seek(0, SEEK_END)
        end = stream.tell()
        stream.seek(start, SEEK_SET)
        self.buffer = []
        last = 0
        use_32 = Flag.is_set(self.flags, self.FLAGS.INDEX_32)
        use_diff = Flag.is_set(self.flags, self.FLAGS.DIFFERENCED_INDICES)
        read_index = s.i32 if use_32 else s.i16
        while stream.tell() < end:
            cur = read_index()
            if use_diff:
                cur += last
                last = cur
            self.buffer.append(cur)

    def _encoded_values(self, use_diff):
        """Return the values exactly as they will be stored (deltas when ``use_diff``)."""
        if not use_diff:
            return list(self.buffer)
        encoded = []
        last = 0
        for index in self.buffer:
            encoded.append(index - last)
            last = index
        return encoded

    def write_rcol(self, stream, rcol):
        """Write the header and indices, choosing 32-bit storage when 16-bit cannot hold them.

        32-bit storage is used when the buffer has more than 0x7FFF entries or
        when any stored value (index or delta) falls outside the signed 16-bit
        range that :meth:`read_rcol` decodes.
        """
        use_diff = Flag.is_set(self.flags, self.FLAGS.DIFFERENCED_INDICES)
        values = self._encoded_values(use_diff)
        needs_32 = len(self.buffer) > 0x7FFF or any(v < -0x8000 or v > 0x7FFF for v in values)
        if needs_32:
            self.flags = Flag.set(self.flags, self.FLAGS.INDEX_32)
        else:
            self.flags = Flag.unset(self.flags, self.FLAGS.INDEX_32)
        self.write_tag(stream)
        s = StreamWriter(stream)
        s.u32(self.version)
        s.u32(self.flags)
        s.u32(self.unknown)
        use_32 = Flag.is_set(self.flags, self.FLAGS.INDEX_32)
        write_index = s.i32 if use_32 else s.i16
        for value in values:
            write_index(value)


class VertexFormat(RCOL):
    """VRTF resource: the per-vertex layout (declarations + stride) of a vertex buffer."""
    TAG = 'VRTF'
    ID = 0x01D0E723

    def __init__(self, key=None):
        self.stride = 0
        self.version = self.VERSION.DEFAULT
        self.is_extended_format = False
        self.declarations = []
        RCOL.__init__(self, key)

    def __eq__(self, other):
        # Defer to object's default equality (identity) so formats are never
        # merged by value.
        return object.__eq__(self, other)

    def __hash__(self):
        return object.__hash__(self)

    def read_rcol(self, stream, rcol):
        """Read the header and declarations (4 x u32 fields if extended, else 4 x u8)."""
        self.read_tag(stream)
        s = StreamReader(stream)
        self.version = s.u32()

        self.stride = s.i32()
        declaration_count = s.i32()
        self.is_extended_format = s.u32() > 0

        read_field = s.u32 if self.is_extended_format else s.u8
        self.declarations = []
        for _ in range(declaration_count):
            declaration = self.Declaration()
            declaration.usage = read_field()
            declaration.usage_index = read_field()
            declaration.format = read_field()
            declaration.offset = read_field()
            self.declarations.append(declaration)

    def write_rcol(self, stream, rcol):
        """Write the header and declarations in the layout :meth:`read_rcol` expects."""
        self.write_tag(stream)
        s = StreamWriter(stream)
        s.u32(self.version)
        s.i32(self.stride)
        s.i32(len(self.declarations))
        s.u32(1 if self.is_extended_format else 0)
        write_field = s.u32 if self.is_extended_format else s.u8
        for declaration in self.declarations:
            write_field(declaration.usage)
            write_field(declaration.usage_index)
            write_field(declaration.format)
            write_field(declaration.offset)

    def add_declaration(self, usage, format):
        """Append a declaration at the current end of the vertex and grow the stride.

        ``usage_index`` is the number of existing declarations with the same usage.
        """
        declaration = self.Declaration()
        declaration.usage = usage
        declaration.format = format
        declaration.offset = self.stride
        for d in self.declarations:
            if d.usage == usage:
                declaration.usage_index += 1
        self.declarations.append(declaration)
        self.stride += self.FORMAT.byte_size(format)

    def max_size_for_usage(self, usage):
        """Return the max encodable magnitude of every declaration with ``usage``.

        For POSITION the list is repeated three times (one copy per axis).
        """
        sizes = [VertexFormat.max_size_for_format(d.format) for d in self.declarations if d.usage == usage]
        if usage == VertexFormat.USAGE.POSITION:
            return sizes * 3
        return sizes

    @staticmethod
    def max_size_for_format(format):
        """Return the largest integer magnitude used to quantize ``format``.

        Raises:
            Exception: for formats with no known quantization range.
        """
        if format in (VertexFormat.FORMAT.SHORT4, VertexFormat.FORMAT.SHORT2, VertexFormat.FORMAT.SHORT2N,
                      VertexFormat.FORMAT.SHORT4N):
            return float(0x7FFF)
        elif format == VertexFormat.FORMAT.USHORT4N:
            return float(0x1FF)
        elif format in (
                VertexFormat.FORMAT.FLOAT, VertexFormat.FORMAT.FLOAT2, VertexFormat.FORMAT.FLOAT3,
                VertexFormat.FORMAT.FLOAT4,
                VertexFormat.FORMAT.FLOAT16_2, VertexFormat.FORMAT.FLOAT16_4):
            return float(0x7FFF)
        else:
            raise Exception('unable to determine max size for format %s' % format)

    @staticmethod
    def from_vertex(vertex):
        """Build a VertexFormat with one declaration per attribute present on ``vertex``.

        One UV declaration is added per UV channel in ``vertex.uv``.
        """
        vrtf = VertexFormat()
        if vertex.position is not None:
            vrtf.add_declaration(VertexFormat.USAGE.POSITION, VertexFormat.FORMAT.SHORT4)
        if vertex.normal is not None:
            vrtf.add_declaration(VertexFormat.USAGE.NORMAL, VertexFormat.FORMAT.COLOR_UBYTE4)
        if vertex.uv is not None:
            for _ in vertex.uv:
                vrtf.add_declaration(VertexFormat.USAGE.UV, VertexFormat.FORMAT.SHORT2)
        if vertex.blend_indices is not None:
            vrtf.add_declaration(VertexFormat.USAGE.BLEND_INDEX, VertexFormat.FORMAT.UBYTE4)
        if vertex.blend_weights is not None:
            vrtf.add_declaration(VertexFormat.USAGE.BLEND_WEIGHT, VertexFormat.FORMAT.COLOR_UBYTE4)
        return vrtf

    @classmethod
    def default_sunshadow(cls):
        """Position-only (SHORT4) layout."""
        vrtf = VertexFormat()
        vrtf.add_declaration(cls.USAGE.POSITION, cls.FORMAT.SHORT4)
        return vrtf

    @classmethod
    def default_drop_shadow(cls):
        """USHORT4N position plus a SHORT4 UV."""
        vrtf = VertexFormat()
        vrtf.add_declaration(cls.USAGE.POSITION, cls.FORMAT.USHORT4N)
        vrtf.add_declaration(cls.USAGE.UV, cls.FORMAT.SHORT4)
        return vrtf

    @staticmethod
    def default_mesh(mesh):
        """Return the default layout for ``mesh`` based on its SHADOW_CASTER flag."""
        if Flag.is_set(mesh.flags, ObjectMesh.Flags.SHADOW_CASTER):
            return VertexFormat.default_sunshadow()
        return VertexFormat.default_drop_shadow()

    class VERSION:
        DEFAULT = 0x00000002

    class USAGE:
        POSITION = 0x00000000
        NORMAL = 0x00000001
        UV = 0x00000002
        BLEND_INDEX = 0x00000003
        BLEND_WEIGHT = 0x00000004
        TANGENT = 0x00000005
        COLOR = 0x00000006

    class FORMAT:
        # Maps stored byte position -> channel index for 4-byte colour/weight data.
        UBYTE_MAP = {0: 0, 1: 1, 2: 2, 3: 3}
        FLOAT = 0x00000000
        FLOAT2 = 0x00000001
        FLOAT3 = 0x00000002
        FLOAT4 = 0x00000003
        UBYTE4 = 0x00000004
        COLOR_UBYTE4 = 0x00000005
        SHORT2 = 0x00000006
        SHORT4 = 0x00000007
        UBYTE4N = 0x00000008
        SHORT2N = 0x00000009
        SHORT4N = 0x0000000A
        USHORT2N = 0x0000000B
        USHORT4N = 0x0000000C
        DEC3N = 0x0000000D
        UDEC3N = 0x0000000E
        FLOAT16_2 = 0x0000000F
        FLOAT16_4 = 0x00000010

        @classmethod
        def float_count(cls, f):
            """Number of decoded float components for format ``f``."""
            if f == cls.FLOAT:
                return 1
            elif f in (cls.FLOAT2, cls.USHORT2N, cls.SHORT2):
                return 2
            elif f in (cls.SHORT4, cls.SHORT4N, cls.UBYTE4N,
                       cls.USHORT4N,
                       cls.FLOAT3):
                return 3
            elif f in (
                    cls.COLOR_UBYTE4, cls.FLOAT4, cls.UBYTE4):
                return 4
            else:
                raise NotImplementedError()

        @classmethod
        def byte_size(cls, f):
            """Size in bytes of one element stored in format ``f``."""
            if f in (
                    cls.FLOAT, cls.UBYTE4, cls.UBYTE4N, cls.COLOR_UBYTE4, cls.SHORT2,
                    cls.SHORT2N, cls.USHORT2N, cls.FLOAT16_2, cls.DEC3N, cls.UDEC3N):
                return 4
            elif f in (cls.USHORT4N, cls.FLOAT2, cls.SHORT4, cls.SHORT4N, cls.FLOAT16_4):
                return 8
            elif f == cls.FLOAT3:
                return 12
            elif f == cls.FLOAT4:
                return 16
            else:
                raise NotImplementedError()

    class Declaration:
        """One vertex element: its usage, usage index, storage format and byte offset."""

        def __init__(self):
            self.usage = VertexFormat.USAGE.POSITION
            self.usage_index = 0
            self.format = VertexFormat.FORMAT.SHORT4
            self.offset = 0


class VertexBuffer(RCOL):
    """VBUF resource: raw interleaved vertex bytes decoded through a VertexFormat (VRTF)."""
    TAG = 'VBUF'
    ID = 0x01D0E6FB

    class Buffer:
        """In-memory vertex byte store with per-vertex encode/decode helpers."""

        def __init__(self):
            self.stream = None
            self.reader = None
            self.writer = None
            self.clear()

        def clear(self):
            """Replace the backing store with a new, empty BytesIO."""
            self.stream = BytesIO()
            self.reader = StreamReader(self.stream)
            self.writer = StreamWriter(self.stream)

        def offset(self):
            """Return the current position of the backing stream."""
            return self.stream.tell()

        def delete_vertices(self, offset, vrtf, count):
            """Remove ``count`` vertices laid out as ``vrtf`` starting at byte ``offset``.

            Bytes following the removed range are moved down to ``offset``.
            """
            end_offset = offset + vrtf.stride * count
            self.stream.seek(end_offset, SEEK_SET)
            end_data = self.stream.read(-1)
            self.stream.seek(offset, SEEK_SET)
            self.stream.truncate()
            # Re-append the data that followed the deleted range.
            self.stream.write(end_data)

        def read_vertices(self, offset, vrtf, count, uvscales):
            """Decode ``count`` consecutive vertices starting at byte ``offset``."""
            self.stream.seek(offset, SEEK_SET)
            vertices = []
            for i in range(count):
                # Each vertex must start exactly one stride after the previous one.
                assert self.stream.tell() == offset + (i * vrtf.stride)
                vertices.append(self.read_vertex(vrtf, uvscales))
            return vertices

        def write_vertices(self, vrtf, vertices, uvscales, posscales):
            """Append ``vertices`` at the end of the buffer; return the byte offset they start at."""
            self.stream.seek(0, SEEK_END)
            offset = self.stream.tell()
            for vertex in vertices:
                start = self.stream.tell()
                self.write_vertex(vrtf, vertex, uvscales, posscales)
                # Every encoded vertex must occupy exactly one stride.
                assert self.stream.tell() - start == vrtf.stride
            self.stream.flush()
            return offset

        def read_vertex(self, vrtf, uvscales):
            """Decode one vertex at the current stream position according to ``vrtf``."""
            vertex = Vertex()
            start = self.stream.tell()
            end = start + vrtf.stride
            for declaration in vrtf.declarations:
                u = declaration.usage
                expected_element = start + declaration.offset + VertexFormat.FORMAT.byte_size(declaration.format)
                value = self.read_element(declaration, uvscales)
                assert self.stream.tell() == expected_element
                if u == VertexFormat.USAGE.POSITION:
                    vertex.position = value
                elif u == VertexFormat.USAGE.NORMAL:
                    vertex.normal = value
                elif u == VertexFormat.USAGE.UV:
                    if vertex.uv is None:
                        vertex.uv = []
                    vertex.uv.append(value)
                elif u == VertexFormat.USAGE.BLEND_INDEX:
                    vertex.blend_indices = value
                elif u == VertexFormat.USAGE.BLEND_WEIGHT:
                    vertex.blend_weights = value
                elif u == VertexFormat.USAGE.COLOR:
                    vertex.colour = value
                elif u == VertexFormat.USAGE.TANGENT:
                    vertex.tangent = value
                else:
                    raise Exception("Unknown usage %s" % declaration.usage)
            assert self.stream.tell() == end
            return vertex

        def write_vertex(self, vrtf, v, uvscales, posscales):
            """Encode vertex ``v`` at the current stream position according to ``vrtf``."""
            start = self.stream.tell()
            end = start + vrtf.stride
            for declaration in vrtf.declarations:
                u = declaration.usage
                if u == VertexFormat.USAGE.POSITION:
                    data = v.position
                elif u == VertexFormat.USAGE.NORMAL:
                    data = v.normal
                elif u == VertexFormat.USAGE.UV:
                    data = v.uv[declaration.usage_index]
                elif u == VertexFormat.USAGE.BLEND_INDEX:
                    data = v.blend_indices
                elif u == VertexFormat.USAGE.BLEND_WEIGHT:
                    data = v.blend_weights
                elif u == VertexFormat.USAGE.COLOR:
                    data = v.colour if v.colour else [1, 1, 1, 0]
                elif u == VertexFormat.USAGE.TANGENT:
                    data = v.tangent
                else:
                    raise Exception('Unknown VRTF usage type %i' % u)
                try:
                    self.write_element(declaration, data, uvscales, posscales)
                except Exception:
                    print('unable to write element: format=%s, usage=%s, data = %s, pos_scales=%s, uv_scales=%s' %
                          (declaration.format, declaration.usage, data, posscales, uvscales))
                    raise
            assert self.stream.tell() == end

        def write_element(self, declaration, value, uvscales, pos_scales):
            """Encode one element ``value`` in the storage format of ``declaration``.

            Raises:
                Exception: for format/usage combinations this writer does not support.
            """
            f = declaration.format
            u = declaration.usage
            float_count = VertexFormat.FORMAT.float_count(declaration.format)

            if f in (VertexFormat.FORMAT.FLOAT, VertexFormat.FORMAT.FLOAT2, VertexFormat.FORMAT.FLOAT3,
                     VertexFormat.FORMAT.FLOAT4):
                for val in value:
                    self.writer.f32(val)
            elif u == VertexFormat.USAGE.UV:
                if f == VertexFormat.FORMAT.SHORT2:
                    for i in range(float_count):
                        self.writer.i16(value[i] / uvscales[declaration.usage_index])
                elif f == VertexFormat.FORMAT.SHORT4 or f == VertexFormat.FORMAT.USHORT4N:
                    # NOTE(review): uses a fixed 0x7FFF scalar, while read_element
                    # decodes SHORT4 UVs with uvscales; confirm this asymmetry.
                    scalar = float(0x7FFF)
                    shorts = [math.floor(value[0] * scalar), math.floor(value[1] * scalar), 0, 0]
                    for short in shorts:
                        self.writer.i16(short)
            elif u in (VertexFormat.USAGE.BLEND_WEIGHT, VertexFormat.USAGE.COLOR):
                for i in range(4):
                    val = value[VertexFormat.FORMAT.UBYTE_MAP[i]] * 0xFF
                    self.writer.u8(min(val, 0xFF))
            elif f == VertexFormat.FORMAT.UBYTE4:
                if not value:
                    print('no data for usage:%s format:%s' % (u, f))
                for val in value:
                    self.writer.i8(val)
            elif f == VertexFormat.FORMAT.UBYTE4N:
                if u not in (VertexFormat.USAGE.NORMAL, VertexFormat.USAGE.TANGENT):
                    raise Exception("Unhandled usage %s for format %s" % (u, f))
                # Copy so padding to four components does not mutate the caller's vertex.
                value = list(value)
                if len(value) < 4:
                    value.append(1.0 if value[-1] < 0 else 0)
                for val in value:
                    # Inverse of the mapping used by read_element for UBYTE4N.
                    encoded = (float(0x7F) * val + float(0x80)) if val < 0 else float(0x7F) * val - float(0x80)
                    self.writer.i8(min(math.floor(encoded), 127))
            elif f == VertexFormat.FORMAT.SHORT2:
                for i in range(float_count):
                    self.writer.i16(value[i] * 0xFFFF)
            elif f == VertexFormat.FORMAT.SHORT4 or f == VertexFormat.FORMAT.USHORT4N:
                # The scale is the reciprocal of an (intended) integer step count;
                # round instead of floor so float error cannot drop it by one.
                scalar = int(round(1.0 / pos_scales[declaration.usage_index]))
                vals = [max(min(int(round(value[i] * scalar, 0)), 0x7FFF), -0x8000) for i in range(3)]
                for val in vals:
                    self.writer.i16(val)
                self.writer.u16(scalar)
            else:
                raise Exception("Unhandled format %s" % f)

        def read_element(self, declaration, uvscales):
            """Decode one element stored in the format of ``declaration``.

            Raises:
                Exception: for format/usage combinations this reader does not support.
            """
            float_count = VertexFormat.FORMAT.float_count(declaration.format)
            value = [0.0] * float_count
            f = declaration.format
            u = declaration.usage
            if f in (VertexFormat.FORMAT.FLOAT, VertexFormat.FORMAT.FLOAT2, VertexFormat.FORMAT.FLOAT3,
                     VertexFormat.FORMAT.FLOAT4):
                for i in range(float_count):
                    value[i] = self.reader.f32()
            elif u == VertexFormat.USAGE.UV:
                if f == VertexFormat.FORMAT.SHORT2:
                    for i in range(float_count):
                        value[i] = self.reader.i16() * uvscales[declaration.usage_index]
                elif f == VertexFormat.FORMAT.SHORT4:
                    shorts = [self.reader.i16() for i in range(4)]
                    assert shorts[2] == 0
                    scale = uvscales[declaration.usage_index]
                    value = [shorts[0] * scale, shorts[1] * scale, shorts[3] * scale]
            elif u in (VertexFormat.USAGE.BLEND_WEIGHT, VertexFormat.USAGE.COLOR):
                vals = [float(self.reader.u8()) for i in range(4)]
                for i in range(4):
                    vals[VertexFormat.FORMAT.UBYTE_MAP[i]] = vals[i] / float(0xFF)
                value = vals
            elif f == VertexFormat.FORMAT.UBYTE4:
                for i in range(float_count):
                    value[i] = self.reader.i8()
            elif f == VertexFormat.FORMAT.UBYTE4N:
                if u not in (VertexFormat.USAGE.NORMAL, VertexFormat.USAGE.TANGENT):
                    raise Exception("Unhandled usage %s for format %s" % (u, f))
                raw = [self.reader.i8() for i in range(4)]
                value = [(float(b) - float(0x80)) / float(0x7F) if b > 0 else (float(b) + float(0x80)) / float(0x7F)
                         for b in raw]
            elif f == VertexFormat.FORMAT.SHORT2:
                for i in range(float_count):
                    value[i] = self.reader.i16() / 0xFFFF
            elif f == VertexFormat.FORMAT.SHORT4:
                # Three components followed by the per-vertex u16 scale divisor.
                shorts = [self.reader.i16() for i in range(3)]
                scalar = self.reader.u16()
                if not scalar:
                    scalar = 0x3FFF
                for i in range(float_count):
                    value[i] = float(shorts[i]) / float(scalar)
            elif f == VertexFormat.FORMAT.USHORT4N:
                shorts = [self.reader.i16() for i in range(3)]
                scalar = self.reader.u16()
                if not scalar:
                    scalar = 511
                for i in range(float_count):
                    value[i] = shorts[i] / scalar
            else:
                raise Exception("Unhandled format %s" % f)
            return value

        def __del__(self):
            if self.stream is not None:
                self.stream.close()

    class VERSION:
        DEFAULT = 0x00000101

    def __init__(self, key=None):
        RCOL.__init__(self, key)
        # Needed by write_rcol for buffers that were never read from a stream.
        self.version = self.VERSION.DEFAULT
        self.swizzle_info = SwizzleInfo(None)
        self.buffer = self.Buffer()

    def read_rcol(self, stream, rcol):
        """Read the header and copy the remaining bytes of ``stream`` into the buffer."""
        s = StreamReader(stream)
        self.read_tag(stream)
        self.version = s.u32()
        # Read outside the assert so the stream stays aligned under ``python -O``.
        reserved = s.u32()
        assert reserved == 0
        self.swizzle_info = rcol.get_block(s.u32(), SwizzleInfo)
        start = stream.tell()
        stream.seek(0, SEEK_END)
        end = stream.tell()
        stream.seek(start, SEEK_SET)
        length = end - start
        self.buffer.stream.seek(0, SEEK_SET)
        self.buffer.stream.truncate()
        self.buffer.stream.write(stream.read(length))
        self.buffer.stream.seek(0, SEEK_SET)

    def write_rcol(self, stream, rcol):
        """Write the header followed by the full contents of the buffer."""
        s = StreamWriter(stream)
        self.write_tag(stream)
        s.u32(self.version)
        s.u32(0)
        s.u32(rcol.get_block_index(self.swizzle_info))
        self.buffer.stream.seek(0, SEEK_SET)
        stream.write(self.buffer.stream.read())


class Swizzle:
    """Swizzle command codes used in ``SwizzleInfo.Segment.commands``.

    One command is stored per 4 bytes of vertex stride. The names indicate a
    single 32-bit word (``SWIZZLE_32``) versus a pair of 16-bit halves
    (``SWIZZLE_16x2``).
    """
    SWIZZLE_16x2 = 0x2
    SWIZZLE_32 = 0x1


class SwizzleInfo(RCOL):
    """Block holding one swizzle ``Segment`` per mesh.

    Each segment lists one ``Swizzle`` command for every 4 bytes of its vertex stride.
    """
    ID = 0x00000000

    class VERSION:
        STANDARD = 0x00000101


    class Segment(Serializable):
        """Swizzle description for one run of vertices."""

        def __init__(self, stream=None):
            self.vertex_size = 0
            self.vertex_count = 0
            self.byte_offset = 0
            self.commands = []
            Serializable.__init__(self, stream)

        def read(self, stream, resource=None):
            s = StreamReader(stream)
            self.vertex_size = s.u32()
            self.vertex_count = s.u32()
            self.byte_offset = s.u32()
            # One u32 command per 4-byte word of the vertex.
            self.commands = [s.u32() for _ in range(self.vertex_size // 4)]

        def write(self, stream, resource=None):
            # Validate before writing anything so a bad segment does not leave
            # a partially written header in the output stream.
            assert len(self.commands) == self.vertex_size // 4
            s = StreamWriter(stream)
            s.u32(self.vertex_size)
            s.u32(self.vertex_count)
            s.u32(self.byte_offset)
            for cmd in self.commands:
                s.u32(cmd)

        @staticmethod
        def from_mesh(mesh):
            """Build a segment describing the vertex layout of ``mesh``.

            Each declaration in ``mesh.vertex_format`` contributes one command per
            4-byte word it occupies:
            - 32-bit components map to ``Swizzle.SWIZZLE_32``.
            - Pairs of 16-bit components map to ``Swizzle.SWIZZLE_16x2``.

            Declarations in formats not listed below contribute no commands. If
            such a format occupies stride bytes, ``write`` would then fail its
            length check.
            """
            vrtf = mesh.vertex_format
            fmt = VertexFormat.FORMAT
            segment = SwizzleInfo.Segment()
            # Take the stride from the same vertex-format object whose
            # declarations are walked below.
            segment.vertex_size = vrtf.stride
            segment.byte_offset = mesh.stream_offset
            segment.vertex_count = mesh.vertex_count
            for d in vrtf.declarations:
                if d.format == fmt.FLOAT4:
                    segment.commands.extend([Swizzle.SWIZZLE_32] * 4)
                elif d.format == fmt.FLOAT3:
                    segment.commands.extend([Swizzle.SWIZZLE_32] * 3)
                elif d.format == fmt.FLOAT2:
                    segment.commands.extend([Swizzle.SWIZZLE_32] * 2)
                elif d.format in (fmt.FLOAT, fmt.UBYTE4, fmt.COLOR_UBYTE4,
                                  fmt.UBYTE4N, fmt.DEC3N, fmt.UDEC3N):
                    segment.commands.append(Swizzle.SWIZZLE_32)
                elif d.format in (fmt.SHORT2, fmt.SHORT2N, fmt.USHORT2N, fmt.FLOAT16_2):
                    segment.commands.append(Swizzle.SWIZZLE_16x2)
                elif d.format in (fmt.SHORT4, fmt.SHORT4N, fmt.USHORT4N, fmt.FLOAT16_4):
                    # Four 16-bit values = 8 bytes = two 16x2 words.
                    segment.commands.extend([Swizzle.SWIZZLE_16x2] * 2)
            return segment

    def __init__(self, key=None):
        RCOL.__init__(self, key)
        self.segments = []

    def read_rcol(self, stream, rcol):
        s = StreamReader(stream)
        self.segments = [self.Segment(stream) for _ in range(s.i32())]

    def write_rcol(self, stream, rcol):
        s = StreamWriter(stream)
        s.i32(len(self.segments))
        for segment in self.segments:
            segment.write(stream)

    def add_mesh(self, mesh):
        """Append a segment describing ``mesh``'s vertex layout."""
        self.segments.append(self.Segment.from_mesh(mesh))

    def clear(self):
        self.segments = []

    # Identity-based equality and hashing.
    def __eq__(self, other):
        return object.__eq__(self, other)

    def __hash__(self):
        return object.__hash__(self)


class IndexBufferShadow(IndexBuffer):
    """``IndexBuffer`` subclass distinguished only by its block ``ID``.

    Equality and hashing are by object identity.
    """
    ID = 0x0229684F

    __eq__ = object.__eq__
    __hash__ = object.__hash__


class VertexBufferShadow(VertexBuffer):
    """``VertexBuffer`` subclass distinguished only by its block ``ID``.

    Equality and hashing are by object identity.
    """
    ID = 0x0229684B

    __eq__ = object.__eq__
    __hash__ = object.__hash__


class Model(RCOL):
    """``MODL`` block: bounds, optional extended fade data, and a list of LOD entries."""
    TAG = 'MODL'
    ID = 0x01661233

    class VERSION():
        STANDARD = 0x00000100
        # Versions >= EXTENDED also carry extra bounds, fade type and custom
        # fade distance.
        EXTENDED = 0x00000102

    class LOD(Serializable):
        """One LOD entry: a ``ModelLod`` reference plus flags, id, sun-shadow flag and z range."""

        class FLAGS:
            NONE = 0x00000000
            PORTAL = 0x00000001
            DOOR = 0x00000002

        def __init__(self, stream=None, rcol=None):
            self.model = None
            self.flags = Model.LOD.FLAGS.NONE
            self.id = LOD_ID.MEDIUM_DETAIL
            # Default so a freshly built LOD can be written without being read first.
            self.is_sunshadow = False
            self.min_z = 0.0
            self.max_z = 0.0
            Serializable.__init__(self, stream, rcol)

        def read(self, stream, resources):
            s = StreamReader(stream)
            self.model = resources.get_block(s.u32(), ModelLod)
            self.flags = s.u32()
            self.id = s.u16()
            self.is_sunshadow = bool(s.u16())
            self.min_z = s.f32()
            self.max_z = s.f32()

        def write(self, stream, resources):
            s = StreamWriter(stream)
            s.u32(resources.get_block_index(self.model, RCOL.Reference.PUBLIC))
            s.u32(self.flags)
            s.u16(self.id)
            s.u16(int(self.is_sunshadow))
            s.f32(self.min_z)
            s.f32(self.max_z)


    def __init__(self, key, stream=None, rcol=None):
        # NOTE(review): ``rcol`` is accepted but not forwarded to RCOL.__init__;
        # confirm this is intended.
        self.version = self.VERSION.STANDARD
        self.bounds = BoundingBox()
        self.extra_bounds = []
        self.fade_type = 0
        self.custom_fade_distance = 0.0
        self.lods = []
        RCOL.__init__(self, key, stream)


    def read_rcol(self, stream, rcol):
        """Read header, bounds, extended fields (if version allows) and LOD entries."""
        s = StreamReader(stream)
        self.read_tag(stream)
        self.version = s.u32()
        lod_count = s.i32()
        self.bounds.read(stream)
        if self.version >= self.VERSION.EXTENDED:
            self.extra_bounds = [BoundingBox(stream=stream) for _ in range(s.i32())]
            self.fade_type = s.u32()
            self.custom_fade_distance = s.f32()
        self.lods = [self.LOD(stream, rcol) for _ in range(lod_count)]

    def write_rcol(self, stream, rcol):
        """Write the block in the same layout ``read_rcol`` consumes."""
        s = StreamWriter(stream)
        self.write_tag(stream)
        s.u32(self.version)
        s.i32(len(self.lods))
        self.bounds.write(stream)
        if self.version >= self.VERSION.EXTENDED:
            s.i32(len(self.extra_bounds))
            for extra in self.extra_bounds:
                extra.write(stream)
            s.u32(self.fade_type)
            s.f32(self.custom_fade_distance)
        for lod in self.lods:
            # LOD is a Serializable; it serialises through ``write`` (it has no
            # ``write_rcol`` method).
            lod.write(stream, rcol)


class ObjectSkinController(RCOL, SkinController):
    """``SKIN`` block: a version number followed by a list of bones.

    Each bone has a name and an inverse bind pose. On disk all names are
    stored first, then all poses.
    """
    TAG = 'SKIN'
    ID = 0x01D0E76B

    def __init__(self, key=None):
        RCOL.__init__(self, key)
        SkinController.__init__(self)
        self.version = 0

    # Identity-based equality and hashing.
    def __eq__(self, other):
        return object.__eq__(self, other)

    def __hash__(self):
        return object.__hash__(self)

    def read_rcol(self, stream, rcol):
        """Read the tag, version and bone table into ``self.bones``."""
        self.read_tag(stream)
        reader = StreamReader(stream)
        self.version = reader.u32()
        bone_count = reader.i32()
        bone_names = [reader.u32() for _ in range(bone_count)]
        bone_poses = [reader.m43() for _ in range(bone_count)]
        self.bones = [self.Bone(name, pose) for name, pose in zip(bone_names, bone_poses)]

    def write_rcol(self, stream, rcol):
        """Write the tag, version and bone table: all names first, then all poses."""
        self.write_tag(stream)
        writer = StreamWriter(stream)
        writer.u32(self.version)
        bones = self.bones
        writer.i32(len(bones))
        for bone in bones:
            writer.hash(bone.name)
        for bone in bones:
            writer.m43(bone.inverse_bind_pose)

